/*-------------------------------------------------------------------------
 *
 * database.c
 *    Commands to interact with the database object in a distributed
 *    environment.
 *
 * Copyright (c) Citus Data, Inc.
 *
 *-------------------------------------------------------------------------
 */

#include "postgres.h"

#include "miscadmin.h"

#include "access/htup.h"
#include "access/xact.h"
#include "catalog/objectaddress.h"
#include "catalog/pg_database.h"
#include "commands/dbcommands.h"
#include "nodes/parsenodes.h"
#include "utils/syscache.h"

#include "distributed/commands.h"
#include "distributed/commands/utility_hook.h"
#include "distributed/deparser.h"
#include "distributed/metadata/distobject.h"
#include "distributed/metadata_sync.h"
#include "distributed/metadata_utility.h"
#include "distributed/multi_executor.h"
#include "distributed/relation_access_tracking.h"
#include "distributed/worker_transaction.h"

/*
 * Forward declarations for functions defined further below. The three static
 * functions are private helpers of this file; PreprocessGrantOnDatabaseStmt has
 * external linkage.
 */
static AlterOwnerStmt* RecreateAlterDatabaseOwnerStmt(Oid databaseOid);
static Oid get_database_owner(Oid db_oid);
static ObjectAddress* GetDatabaseAddressFromDatabaseName(char* databaseName,
                                                         bool missingOk);
List* PreprocessGrantOnDatabaseStmt(Node* node, const char* queryString,
                                    ProcessUtilityContext processUtilityContext);

/*
 * controlled via GUC
 *
 * NOTE(review): this flag defaults to true and is not read by any function visible
 * in this file; presumably it gates propagation of database owner changes -- confirm
 * where the GUC is registered and consumed.
 */
bool EnableAlterDatabaseOwner = true;

/*
 * AlterDatabaseOwnerObjectAddress resolves the database named by an AlterOwnerStmt
 * (whose objectType must be OBJECT_DATABASE) and returns a one-element list holding a
 * freshly allocated ObjectAddress for it.
 *
 * missing_ok is handed to get_database_oid unchanged, so the lookup errors there for
 * a missing database when missing_ok is false. isPostprocess is not used.
 */
List* AlterDatabaseOwnerObjectAddress(Node* node, bool missing_ok, bool isPostprocess)
{
    AlterOwnerStmt* alterOwnerStmt = castNode(AlterOwnerStmt, node);
    Assert(alterOwnerStmt->objectType == OBJECT_DATABASE);

    /* the first entry of the statement's object list carries the database name */
    char* databaseName = strVal(linitial(alterOwnerStmt->object));
    Oid resolvedOid = get_database_oid(databaseName, missing_ok);

    ObjectAddress* databaseAddress =
        static_cast<ObjectAddress*>(palloc0(sizeof(ObjectAddress)));
    ObjectAddressSet(*databaseAddress, DatabaseRelationId, resolvedOid);

    return list_make1(databaseAddress);
}

/*
 * DatabaseOwnerDDLCommands builds the sql needed to idempotently re-apply the current
 * owner of the database identified by address, so that the database ends up owned by
 * the same user on every node in the cluster.
 *
 * Returns a one-element list containing the deparsed command string.
 */
List* DatabaseOwnerDDLCommands(const ObjectAddress* address)
{
    AlterOwnerStmt* ownerStmt = RecreateAlterDatabaseOwnerStmt(address->objectId);
    char* ownerCommand = DeparseTreeNode((Node*)ownerStmt);

    return list_make1(ownerCommand);
}

/*
 * RecreateAlterDatabaseOwnerStmt builds an AlterOwnerStmt describing "change the owner
 * of this database to the role that currently owns it". The database name and the
 * owner are both looked up from databaseOid.
 */
static AlterOwnerStmt* RecreateAlterDatabaseOwnerStmt(Oid databaseOid)
{
    AlterOwnerStmt* ownerStmt = makeNode(AlterOwnerStmt);
    ownerStmt->objectType = OBJECT_DATABASE;

    char* databaseName = get_database_name(databaseOid);
    ownerStmt->object = list_make1(makeString(databaseName));

    /* errors out if the database cannot be found in the catalog */
    Oid currentOwner = get_database_owner(databaseOid);
#ifdef DISABLE_OG_COMMENTS
    /* only compiled when DISABLE_OG_COMMENTS is defined */
    ownerStmt->newowner = makeNode(RoleSpec);
    ownerStmt->newowner->roletype = ROLESPEC_CSTRING;
#endif
    /* assigned unconditionally, so this is the value newowner ends up with */
    ownerStmt->newowner = GetUserNameFromId(currentOwner);

    return ownerStmt;
}

/*
 * get_database_owner looks up the pg_database entry for db_oid in the syscache and
 * returns the Oid stored in its datdba field, i.e. the role owning the database.
 *
 * Raises ERRCODE_UNDEFINED_DATABASE when no entry exists for db_oid.
 */
static Oid get_database_owner(Oid db_oid)
{
    HeapTuple dbTuple = SearchSysCache1(DATABASEOID, ObjectIdGetDatum(db_oid));
    if (!HeapTupleIsValid(dbTuple)) {
        ereport(ERROR, (errcode(ERRCODE_UNDEFINED_DATABASE),
                        errmsg("database with OID %u does not exist", db_oid)));
    }

    /* copy the owner out before the cache reference is released */
    Form_pg_database dbForm = (Form_pg_database)GETSTRUCT(dbTuple);
    Oid ownerOid = dbForm->datdba;
    ReleaseSysCache(dbTuple);

    return ownerOid;
}

/*
 * PreprocessGrantOnDatabaseStmt runs before a GRANT ... ON DATABASE statement is
 * applied to the local postgres instance.
 *
 * It prepares the tasks that replay the statement on all non-coordinator nodes.
 * Returns NIL when ShouldPropagateDDL() says no or when the statement lists no
 * databases; otherwise the caller must be the coordinator (EnsureCoordinator).
 * queryString and processUtilityContext are not used.
 */
List* PreprocessGrantOnDatabaseStmt(Node* node, const char* queryString,
                                    ProcessUtilityContext processUtilityContext)
{
    if (!ShouldPropagateDDL()) {
        return NIL;
    }

    GrantStmt* grantStmt = castNode(GrantStmt, node);
    Assert(grantStmt->objtype == ACL_OBJECT_DATABASE);

    /* nothing to send when the statement names no database at all */
    if (list_length(grantStmt->objects) == 0) {
        return NIL;
    }

    EnsureCoordinator();

    /* send the deparsed statement between the disable/enable propagation commands */
    char* grantCommand = DeparseTreeNode((Node*)grantStmt);
    List* workerCommands = list_make3(void_cast(DISABLE_DDL_PROPAGATION),
                                      (void*)grantCommand,
                                      void_cast(ENABLE_DDL_PROPAGATION));

    return NodeDDLTaskList(NON_COORDINATOR_NODES, workerCommands);
}

/*
 * PreprocessAlterDatabaseStmt runs before an ALTER DATABASE statement is applied to
 * the local postgres instance.
 *
 * It prepares the tasks that replay the statement on all non-coordinator nodes.
 * The target database is resolved first, with missingOk = false. NIL is returned when
 * ShouldPropagate() says no or when the database is not a distributed object;
 * otherwise the caller must be the coordinator (EnsureCoordinator).
 * queryString and processUtilityContext are not used.
 */
List* PreprocessAlterDatabaseStmt(Node* node, const char* queryString,
                                  ProcessUtilityContext processUtilityContext)
{
    AlterDatabaseStmt* alterStmt = castNode(AlterDatabaseStmt, node);

    /* the lookup happens before any propagation check */
    ObjectAddress* targetAddress =
        GetDatabaseAddressFromDatabaseName(alterStmt->dbname, false);

    if (!ShouldPropagate()) {
        return NIL;
    }
    if (!IsAnyObjectDistributed(list_make1(targetAddress))) {
        return NIL;
    }

    EnsureCoordinator();

    /* send the deparsed statement between the disable/enable propagation commands */
    char* alterCommand = DeparseTreeNode((Node*)alterStmt);
    List* workerCommands = list_make3(void_cast(DISABLE_DDL_PROPAGATION),
                                      (void*)alterCommand,
                                      void_cast(ENABLE_DDL_PROPAGATION));

    return NodeDDLTaskList(NON_COORDINATOR_NODES, workerCommands);
}

/*
 * PreprocessAlterDatabaseRefreshCollStmt runs before an AlterDatabaseRefreshCollStmt
 * is applied to the local postgres instance.
 *
 * Returns NIL when ShouldPropagate() says no. Beyond that, the behaviour depends on
 * the build: unless DISABLE_OG_COMMENTS is defined the function always returns NIL;
 * when it is defined, the statement is deparsed and turned into tasks for all
 * non-coordinator nodes (the caller must then be the coordinator).
 * queryString and processUtilityContext are not used.
 */
List* PreprocessAlterDatabaseRefreshCollStmt(Node* node, const char* queryString,
                                             ProcessUtilityContext processUtilityContext)
{
    if (!ShouldPropagate()) {
        return NIL;
    }

#ifndef DISABLE_OG_COMMENTS
    /* propagation of this statement is compiled out in this configuration */
    return NIL;
#else
    AlterDatabaseRefreshCollStmt* refreshStmt =
        castNode(AlterDatabaseRefreshCollStmt, node);

    EnsureCoordinator();

    /* send the deparsed statement between the disable/enable propagation commands */
    char* refreshCommand = DeparseTreeNode((Node*)refreshStmt);
    List* workerCommands = list_make3(void_cast(DISABLE_DDL_PROPAGATION),
                                      (void*)refreshCommand,
                                      void_cast(ENABLE_DDL_PROPAGATION));

    return NodeDDLTaskList(NON_COORDINATOR_NODES, workerCommands);
#endif
}

/*
 * GetDatabaseAddressFromDatabaseName resolves databaseName to its Oid through
 * get_database_oid (forwarding missingOk) and returns a newly allocated ObjectAddress
 * with class DatabaseRelationId pointing at that Oid.
 */
static ObjectAddress* GetDatabaseAddressFromDatabaseName(char* databaseName,
                                                         bool missingOk)
{
    Oid resolvedOid = get_database_oid(databaseName, missingOk);

    ObjectAddress* address = static_cast<ObjectAddress*>(palloc0(sizeof(ObjectAddress)));
    ObjectAddressSet(*address, DatabaseRelationId, resolvedOid);

    return address;
}
